{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/3_train_gail.ipynb)\n",
    "# Train an Agent using Generative Adversarial Imitation Learning\n",
    "\n",
    "Generative adversarial imitation learning (GAIL) trains a discriminator network to distinguish expert trajectories from learner trajectories.\n",
    "The learner is trained with a standard reinforcement learning algorithm such as PPO and is rewarded for producing trajectories that the discriminator mistakes for expert trajectories."
   ]
  },
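  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the two roles described above, the following NumPy sketch computes a binary cross-entropy discriminator loss and a learner reward from discriminator logits. It is not the `imitation` implementation: the helper names (`log_sigmoid`, `discriminator_loss`, `learner_reward`), the labelling convention (expert = 1, learner = 0), the reward form `-log(1 - D)` and the sample logits are all assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def log_sigmoid(x):\n",
    "    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).\n",
    "    return -np.logaddexp(0.0, -x)\n",
    "\n",
    "\n",
    "def discriminator_loss(expert_logits, learner_logits):\n",
    "    # Binary cross-entropy with expert samples labelled 1 and learner samples labelled 0.\n",
    "    expert_term = -log_sigmoid(expert_logits)  # -log D\n",
    "    learner_term = -log_sigmoid(-learner_logits)  # -log(1 - D)\n",
    "    return np.mean(np.concatenate([expert_term, learner_term]))\n",
    "\n",
    "\n",
    "def learner_reward(learner_logits):\n",
    "    # Assumed reward form: r = -log(1 - D), which grows as D approaches 1.\n",
    "    return -log_sigmoid(-learner_logits)\n",
    "\n",
    "\n",
    "demo_rng = np.random.default_rng(0)\n",
    "expert_logits = demo_rng.normal(loc=2.0, size=5)  # hypothetical discriminator outputs\n",
    "learner_logits = demo_rng.normal(loc=-2.0, size=5)\n",
    "print('discriminator loss:', discriminator_loss(expert_logits, learner_logits))\n",
    "print('learner rewards:', learner_reward(learner_logits))\n",
    "```\n",
    "\n",
    "In this sketch the discriminator lowers its loss by separating the two batches, while the learner earns a higher reward whenever the discriminator assigns its samples a higher probability of being expert data."
   ]
  },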
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As usual, we first need an expert. Again, we download one from the HuggingFace model hub for convenience.\n",
    "\n",
    "Note that we use a variant of the CartPole environment from the seals package, which has fixed episode durations. Read more about why we do this [here](https://imitation.readthedocs.io/en/latest/main-concepts/variable_horizon.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from imitation.policies.serialize import load_policy\n",
    "from imitation.util.util import make_vec_env\n",
    "from imitation.data.wrappers import RolloutInfoWrapper\n",
    "\n",
    "SEED = 42\n",
    "\n",
    "env = make_vec_env(\n",
    "    \"seals:seals/CartPole-v0\",\n",
    "    rng=np.random.default_rng(SEED),\n",
    "    n_envs=8,\n",
    "    post_wrappers=[\n",
    "        lambda env, _: RolloutInfoWrapper(env)\n",
    "    ],  # needed for computing rollouts later\n",
    ")\n",
    "expert = load_policy(\n",
    "    \"ppo-huggingface\",\n",
    "    organization=\"HumanCompatibleAI\",\n",
    "    env_name=\"seals/CartPole-v0\",\n",
    "    venv=env,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We generate some expert trajectories that the discriminator needs to distinguish from the learner's trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.data import rollout\n",
    "\n",
    "rollouts = rollout.rollout(\n",
    "    expert,\n",
    "    env,\n",
    "    rollout.make_sample_until(min_timesteps=None, min_episodes=60),\n",
    "    rng=np.random.default_rng(SEED),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to set up our GAIL trainer.\n",
    "Note that the `reward_net` is actually the discriminator network.\n",
    "We evaluate the learner before and after training so we can see if it made any progress.\n",
    "\n",
    "First we construct a GAIL trainer ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.algorithms.adversarial.gail import GAIL\n",
    "from imitation.rewards.reward_nets import BasicRewardNet\n",
    "from imitation.util.networks import RunningNorm\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.ppo import MlpPolicy\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "learner = PPO(\n",
    "    env=env,\n",
    "    policy=MlpPolicy,\n",
    "    batch_size=64,\n",
    "    ent_coef=0.0,\n",
    "    learning_rate=0.0004,\n",
    "    gamma=0.95,\n",
    "    n_epochs=5,\n",
    "    seed=SEED,\n",
    ")\n",
    "reward_net = BasicRewardNet(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    normalize_input_layer=RunningNorm,\n",
    ")\n",
    "gail_trainer = GAIL(\n",
    "    demonstrations=rollouts,\n",
    "    demo_batch_size=1024,\n",
    "    gen_replay_buffer_capacity=512,\n",
    "    n_disc_updates_per_round=8,\n",
    "    venv=env,\n",
    "    gen_algo=learner,\n",
    "    reward_net=reward_net,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... then we evaluate it before training ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.seed(SEED)\n",
    "learner_rewards_before_training, _ = evaluate_policy(\n",
    "    learner, env, 100, return_episode_rewards=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... and train it ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gail_trainer.train(200_000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... and finally evaluate it again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.seed(SEED)\n",
    "learner_rewards_after_training, _ = evaluate_policy(\n",
    "    learner, env, 100, return_episode_rewards=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the untrained policy performs poorly, while the GAIL-trained policy reaches returns close to the expert's (500):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    \"Rewards before training:\",\n",
    "    np.mean(learner_rewards_before_training),\n",
    "    \"+/-\",\n",
    "    np.std(learner_rewards_before_training),\n",
    ")\n",
    "print(\n",
    "    \"Rewards after training:\",\n",
    "    np.mean(learner_rewards_after_training),\n",
    "    \"+/-\",\n",
    "    np.std(learner_rewards_after_training),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
